import torch
import os
import numpy as np
import cv2
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
import torchvision
from xml.dom.minidom import parse
import utils
import transforms as T
from engine import train_one_epoch, evaluate
import xml.etree.cElementTree as ET
import collections
import pandas as pd
from torchvision.transforms import functional
import random
from aip import AipFace
import base64
import time
import tkinter as tk
from tkinter import filedialog
import shutil
import socket


""" 你的 APPID AK SK """
APP_ID = '22912073'
API_KEY = '3zbPMNiqWOrsD5BspgX1pBoR'
SECRET_KEY = 'nxSoiCZnSdOV9DVTPhHrlMvAYvTHZNa4'
# aip_face对象
aipFace = AipFace(APP_ID, API_KEY, SECRET_KEY)

faces_name_dict = {
    'Paulmale': 'Paul Male',
    'Danielmale': 'Daniel Male',
    'Fishermale': 'Fisher Male',
    'Jackmale': 'Jack Male',
    'Kevinmale': 'Kevin Male',
    'lilyfemale': 'lily Female',
    'Rosefemale': 'Rose Female',
    'Maryfemale': 'Mary Female',
    'Michaelmale': 'Michael Male',
    'stevenmale': 'Steven Male',
    'Jamesmale': 'James Male'
}


# Detect all faces (with gender etc.) in an image via Baidu AIP.
def face_detection(img_path, min_score=0.96):
    """Run Baidu AIP face detection on the image at ``img_path``.

    Parameters
    ----------
    img_path : str
        Path of the image file to analyse.
    min_score : float, optional
        Minimum ``face_probability`` a detection must reach to be kept.
        Defaults to 0.96, the previously hard-coded threshold.

    Returns
    -------
    list[dict]
        One dict per accepted face with keys ``location``, ``score``,
        ``gender``, ``age``, ``beauty``, ``glasses`` and ``angle``.
        Empty when the API call fails or nothing passes the threshold.

    Bug fixed: the glasses attribute was previously stored under the
    misspelled key ``'glassed'``; it is now ``'glasses'`` (no reader of the
    old key exists in this file).
    """
    faces = []

    # Read the image and base64-encode it, as the API requires.
    with open(img_path, "rb") as fp:
        image = str(base64.b64encode(fp.read()), 'utf-8')
    imageType = "BASE64"

    # Request exactly the extra attributes consumed below.
    options = {
        "face_field": "age,beauty,gender,glasses,quality,eye_status,face_type",
        "max_face_num": 10,
        "face_type": "LIVE",
    }
    # options['liveness_control'] = 'LOW'

    """ 带参数调用人脸搜索 """
    response = aipFace.detect(image, imageType, options)
    # Bail out with an empty list on any API failure.
    if response['error_msg'] != 'SUCCESS':
        return faces

    result = response['result']
    face_list = result['face_list']

    for idx in range(result['face_num']):
        entry = face_list[idx]
        score = entry['face_probability']
        # Keep only sufficiently confident detections.
        if score < min_score:
            continue
        faces.append({
            'location': entry['location'],
            'score': score,
            'gender': entry['gender']['type'],
            'age': entry['age'],
            'beauty': entry['beauty'],
            'glasses': entry['glasses']['type'],
            'angle': entry['angle'],
        })

    return faces


# M:N face search: return the matched faces and their related info.
def predict_faces(img_path):
    """Search faces in the image against the 'main' group via Baidu AIP.

    Returns a list of dicts with keys ``location``, ``user_id`` (the stored
    user_info stem mapped to a display name through ``faces_name_dict``) and
    ``score``; empty on API failure.
    """
    matched = []

    # Load the file and base64-encode it for the API payload.
    with open(img_path, "rb") as handle:
        encoded = base64.b64encode(handle.read())
    image = str(encoded, 'utf-8')
    imageType = "BASE64"

    # Search options: up to 10 faces, a deliberately low match threshold,
    # no quality/liveness filtering, and only the single best user per face.
    groupIdList = 'main'
    options = {
        "max_face_num": 10,
        "match_threshold": 10,
        "quality_control": "NONE",
        "liveness_control": "NONE",
        "max_user_num": 1,
    }
    # options["user_id"] = "233451"

    """ 带参数调用人脸搜索 """
    response = aipFace.multiSearch(image, imageType, groupIdList, options)
    # Bail out with an empty list on any API failure.
    if response['error_msg'] != 'SUCCESS':
        return matched

    result = response['result']
    for entry in result['face_list'][:result['face_num']]:
        best_match = entry['user_list'][0]
        # user_info is a filename like 'Paulmale.jpg'; map its stem to a name.
        stem = best_match['user_info'].split('.')[0]
        matched.append({
            'location': entry['location'],
            'user_id': faces_name_dict[stem],
            'score': best_match['score'],
        })

    return matched


def predict_once(model, img):
    """Run one forward pass of the detector on a single image.

    ``img`` is an HxWxC RGB array; the return value is the raw torchvision
    detection output, shaped like::

        [{'boxes': tensor([[1492.6672,  238.4670, 1765.5385,  315.0320],
          [ 887.1390,  256.8106, 1154.6687,  330.2953]], device='cuda:0'),
          'labels': tensor([1, 1], device='cuda:0'),
          'scores': tensor([1.0000, 1.0000], device='cuda:0')}]
    """
    # Inference mode: freeze dropout / batch-norm behaviour.
    model.eval()

    tensor_img = functional.to_tensor(img)

    # Prefer the GPU when one is available, otherwise run on the CPU.
    target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    with torch.no_grad():
        result = model([tensor_img.to(target_device)])

    return result


def random_color():
    """Return a random (B, G, R) tuple with each channel in [50, 200]."""
    # Three consecutive randint draws, in B, G, R order.
    return tuple(random.randint(50, 200) for _ in range(3))


# Mouse event callback: records the corners of user-drawn selection boxes.
def on_mouse(event, x, y, flags, param):
    """Store box corners into the shared ``param`` dict of parallel lists.

    A left-button press appends (x, y) to ``param['x1']``/``param['y1']``;
    the matching release appends the opposite corner to ``param['x2']``/
    ``param['y2']``. All other events are ignored.
    """
    if event == cv2.EVENT_LBUTTONDOWN:
        corner_keys = ('x1', 'y1')  # first corner of the box
    elif event == cv2.EVENT_LBUTTONUP:
        corner_keys = ('x2', 'y2')  # opposite corner of the same box
    else:
        return
    param[corner_keys[0]].append(x)
    param[corner_keys[1]].append(y)


def face_iou_cal(face1, face2):
    """Return the intersection-over-union of two faces' bounding boxes.

    Each face dict carries a ``location`` with ``left``, ``top``, ``width``
    and ``height``. Non-overlapping (or merely touching) boxes yield 0;
    identical boxes yield 1.0.

    Bug fixed: the original divided the intersection area by
    ``min_corner_box + max_corner_box - intersection`` (boxes formed from
    the element-wise min/max corners of the two inputs) instead of the true
    union ``area1 + area2 - intersection``, so the result was not a proper
    IoU whenever the boxes differed.
    """
    loc1 = face1['location']
    loc2 = face2['location']

    # Corner coordinates of each box: (x1, y1) top-left, (x2, y2) bottom-right.
    ax1, ay1 = loc1['left'], loc1['top']
    ax2, ay2 = ax1 + loc1['width'], ay1 + loc1['height']
    bx1, by1 = loc2['left'], loc2['top']
    bx2, by2 = bx1 + loc2['width'], by1 + loc2['height']

    # Intersection rectangle: max of the lefts/tops, min of the rights/bottoms.
    ix1 = max(ax1, bx1)
    iy1 = max(ay1, by1)
    ix2 = min(ax2, bx2)
    iy2 = min(ay2, by2)

    # No overlapping area at all.
    if ix2 <= ix1 or iy2 <= iy1:
        return 0

    inter = (ix2 - ix1) * (iy2 - iy1)
    area1 = loc1['width'] * loc1['height']
    area2 = loc2['width'] * loc2['height']
    # IoU = intersection / union.
    return inter / (area1 + area2 - inter)


def main(args):
    """Interactive recognition loop driven by a cv2 window and file dialogs.

    Flow per image: the user picks a file, drags rectangles with the mouse,
    and presses a key; object detection runs inside each rectangle, then
    Baidu face search (known users) and face detection (gender for
    strangers) run on the whole image. Results are drawn onto the image and
    appended to a text log under ``dataset/prediction``.

    args: argparse namespace with ``num`` (list[int]; only the first entry
    is actually used as the per-image item limit) and ``model`` (model file
    stem under ``dataset/models``, loaded as a whole object via torch.load).
    """
    # Take the parameters off the command line
    # print(args)
    item_nums = args.num
    # print('num:', item_nums)

    root = r'dataset'
    prediction_root = os.path.join(root, 'prediction')
    prediction_txt_path = os.path.join(prediction_root, 'PREDICTION_DOCUMENT.txt')
    score_threshold = 0.5
    # model_selected = '59'
    # model_selected = 'last_model'
    model_selected = args.model
    # Fresh temp dir for working copies and masked intermediate images.
    temp_root = os.path.join(root, 'temp')
    if not os.path.exists(temp_root):
        os.mkdir(temp_root)
    else:
        for file_name in os.listdir(temp_root):
            os.remove(os.path.join(temp_root, file_name))

    # Prepare (truncate) the prediction log file
    with open(prediction_txt_path, 'w', encoding='utf-8') as file:
        file.write('PREDICTION:\n')

    # Parse the label_list file
    with open(os.path.join(root, "label_list.txt"), 'r') as file:
        label_list = file.readlines()
    label_list = [label.rstrip() for label in label_list]  # strip trailing whitespace
    label_list = [label for label in label_list if label != '']  # drop empty lines

    # The number of classes follows the label list automatically
    num_classes = len(label_list)
    # print(label_list)
    # print(num_classes)

    # Load the model (a fully pickled model object, not a state_dict)
    print('Loading Selected Model...\n')
    model_path = os.path.join(root, 'models', model_selected + '.pkl')
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = torch.load(model_path, map_location=device)
    model.to(device)

    # Switch to inference mode
    model.eval()
    print('Loading Selected Model...Done\n')

    # Mouse-callback state: parallel lists of box corners (see on_mouse)
    mouse_param = {}
    mouse_param['x1'] = []
    mouse_param['y1'] = []
    mouse_param['x2'] = []
    mouse_param['y2'] = []
    cv2.namedWindow('img')  # create an empty window first
    cv2.setMouseCallback('img', on_mouse, mouse_param)

    # With the model loaded, keep opening images in a loop
    img_name_count = 0
    img_root = os.path.join(root, 'test')

    while True:
        print('Please select A picture:\n')
        img_name_count += 1
        tk1 = tk.Tk()
        tk1.withdraw()
        img_path = filedialog.askopenfilename()
        src = img_path
        if src == '':
            break
        # Work on a copy inside temp_root so the original file is untouched.
        dst = os.path.join(temp_root, str(img_name_count) + 'origin_img.jpg')
        img_path = dst  # path of the working copy
        shutil.copy(src, dst)

        # Read the image with cv2
        src_img = cv2.imread(img_path)
        cv2.imshow('img', src_img)
        img = cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB)

        # Wait for a key press to start recognition
        print('Wait key to predict once...\n')
        key = cv2.waitKey()
        print('Predicting...\n')
        if key == 27:
            break
        else:  # the actual detection part
            # Build the list of cropped regions
            imgs = []
            for idx in range(len(mouse_param['x1'])):  # one crop per drawn box
                t_img = img[mouse_param['y1'][idx]:mouse_param['y2'][idx], mouse_param['x1'][idx]:mouse_param['x2'][idx]]
                imgs.append(t_img)

            # Detect each crop separately, then merge the results.
            # Starts here
            overall_boxes = []
            overall_scores = []
            overall_labels = []
            overall_items = []  # list of tuples holding every detected item's info
            # NOTE(review): only the first -n value is ever used.
            item_num = item_nums[0]
            for img_idx in range(len(imgs)):
                # if img_idx > len(item_nums):  # ignore extra mouse-drawn boxes
                #     break
                img = imgs[img_idx]
                # item_num = item_nums[img_idx]
                # Run one model prediction on this crop
                prediction = predict_once(model=model, img=img)
                # print(prediction[0]['labels'])
                # print(prediction[0]['scores'])

                # Offsets of this crop within the full image
                offset_x = mouse_param['x1'][img_idx]
                offset_y = mouse_param['y1'][img_idx]

                # Unpack the prediction
                boxes = prediction[0]['boxes']
                boxes = boxes.cpu().numpy().tolist()  # to Python list
                for idx in range(len(boxes)):  # shift boxes back to full-image coords here to simplify the IOU math
                    boxes[idx][0] += offset_x
                    boxes[idx][1] += offset_y
                    boxes[idx][2] += offset_x
                    boxes[idx][3] += offset_y
                scores = prediction[0]['scores']
                scores = scores.cpu().numpy().tolist()  # to Python list
                labels = prediction[0]['labels']
                labels = [label_list[label] for label in labels]
                # print(labels)

                # Stash this crop's detections
                for idx in range(len(boxes)):
                    overall_items.append(tuple([(boxes[idx]), scores[idx], labels[idx], offset_x, offset_y]))
                # print(overall_items)

            # Sort the merged items by score, descending
            overall_items = sorted(overall_items, key=lambda s: s[1], reverse=True)
            # print(overall_items)

            # Unpack the items into parallel lists
            labels = []
            boxes = []
            scores = []
            offset = []
            for idx in range(len(overall_items)):
                overall_item = overall_items[idx]
                boxes.append(overall_item[0])
                scores.append(overall_item[1])
                labels.append(overall_item[2])
                offset.append(tuple([overall_item[3], overall_item[4]]))

            # print(boxes)
            # print(labels)
            # print(scores)

            # Per-label best-detection bookkeeping (parallel lists)
            draw_list = {}
            draw_list['label'] = []
            draw_list['max_score'] = []
            draw_list['idx'] = []
            # draw_list['box'] = []

            # Keep the best-scoring detection for each label
            for idx in range(len(boxes)):
                # print(idx)
                # draw_dict = {}
                # First time this label shows up
                if labels[idx] not in draw_list['label']:
                    draw_list['label'].append(labels[idx])
                    draw_list['max_score'].append(float(0))
                    draw_list['idx'].append(int(idx))
                    # draw_list['box'].append(boxes[idx])
                # Map the raw index to this label's slot inside draw_list
                draw_idx = draw_list['label'].index(labels[idx])
                if scores[idx] > draw_list['max_score'][draw_idx]:
                    draw_list['max_score'][draw_idx] = float(scores[idx])
                    draw_list['idx'][draw_idx] = int(idx)
                    # draw_list['box'][draw_idx] = dict(boxes[idx])

            # print(draw_list)

            # After the search the survivors are ordered by score, high to low.
            # Compare every pair's IOU and drop the lower-scored of any
            # heavily overlapping pair.

            iou_pass_flag = 0
            while not iou_pass_flag:
                to_delete_idx_list = []  # draw_list indices to delete
                # Check the IOU among the first item_num boxes
                for i in range(item_num):  # i is the reference box
                    if len(draw_list['idx']) <= item_num:
                        iou_pass_flag = 1
                        break
                    for j in range(item_num - 1, i,
                                   -1):  # walk j backwards from item_num-1 to i+1; later entries always score lower and j never equals i, so on an IOU breach j is the one to delete
                        # print('i:', i)
                        # print('j:', j)
                        raw_idx1 = draw_list['idx'][i]
                        raw_idx2 = draw_list['idx'][j]
                        x11 = min(float(boxes[raw_idx1][0]), float(boxes[raw_idx2][0]))
                        y11 = min(float(boxes[raw_idx1][1]), float(boxes[raw_idx2][1]))
                        x12 = min(float(boxes[raw_idx1][2]), float(boxes[raw_idx2][2]))
                        y12 = min(float(boxes[raw_idx1][3]), float(boxes[raw_idx2][3]))
                        x21 = max(float(boxes[raw_idx1][0]), float(boxes[raw_idx2][0]))
                        y21 = max(float(boxes[raw_idx1][1]), float(boxes[raw_idx2][1]))
                        x22 = max(float(boxes[raw_idx1][2]), float(boxes[raw_idx2][2]))
                        y22 = max(float(boxes[raw_idx1][3]), float(boxes[raw_idx2][3]))
                        iou = 0
                        # No overlapping area
                        if x12 < x21 or y12 < y21:
                            iou = 0
                        else:
                            # NOTE(review): approximate-IOU formula — the
                            # denominator mixes min-corner/max-corner boxes
                            # rather than the two true box areas; confirm
                            # whether standard IOU was intended.
                            h1 = y12 - y11
                            w1 = x12 - x11
                            h2 = y22 - y21
                            w2 = x22 - x21
                            h = y12 - y21
                            w = x12 - x21
                            area1 = h1 * w1
                            area2 = h2 * w2
                            area = h * w
                            iou = area / (area1 + area2 - area)
                            # IOU over the limit
                            if iou > 0.5:
                                to_delete_idx_list.append(j)
                                pass
                        # print('iou:', iou)
                # Remove the flagged entries
                to_delete_idx_list = list((set(to_delete_idx_list)))  # dedupe
                to_delete_idx_list.sort()
                to_delete_idx_list.reverse()  # reverse first so pops don't shift pending indices
                # print(to_delete_idx_list)
                for idx in to_delete_idx_list:  # delete
                    draw_list['label'].pop(idx)
                    draw_list['max_score'].pop(idx)
                    draw_list['idx'].pop(idx)
                    # draw_list['box'].pop(idx)
                # Nothing left to delete: the check is complete, exit directly
                if not to_delete_idx_list:
                    iou_pass_flag = 1
                    break

            # print(draw_list)
            # Draw only the first item_num boxes
            count = 0
            for idx in draw_list['idx']:
                count += 1
                if count > item_num:
                    break
                x1, y1, x2, y2 = boxes[idx][0], boxes[idx][1], boxes[idx][2], boxes[idx][3]
                # x1 += offset[idx][0]  # offsets are no longer needed when drawing
                # x2 += offset[idx][0]
                # y1 += offset[idx][1]
                # y2 += offset[idx][1]
                label = labels[idx]
                color = random_color()
                cv2.rectangle(src_img, (int(x1), int(y1)), (int(x2), int(y2)), color=color, thickness=2)
                cv2.rectangle(src_img, (int(x1), int(y2)), (int(x2), int(y2 + 20)), color=color, thickness=-1)
                cv2.putText(src_img, labels[idx], (int(x1), int(y2 + 16)), fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                            fontScale=0.5,
                            color=(0, 0, 0))
                # cv2.putText(src_img, str(round(float(scores[idx]), 2)), (int(x1), int(y2)),
                #             fontFace=cv2.FONT_HERSHEY_TRIPLEX, fontScale=1,
                #             color=color, thickness=2)

                # Append to the txt log
                with open(prediction_txt_path, 'a', encoding='utf-8') as f:
                    f.write('boxes:\t')
                    f.write('x1: ' + str(int(x1)) + '\t')
                    f.write('y1: ' + str(int(y1)) + '\t')
                    f.write('x2: ' + str(int(x2)) + '\t')
                    f.write('y2: ' + str(int(y2)) + '\t')
                    f.write('label: ' + str(label) + '\t')
                    f.write('\n')

            # Clear the mouse state after each complete pass
            mouse_param['x1'].clear()
            mouse_param['x2'].clear()
            mouse_param['y1'].clear()
            mouse_param['y2'].clear()

            # Face search (M:N): repeatedly search, mask the hits, re-search
            faces = []
            times_idx = 0
            next_face_img_path = img_path
            while True:
                times_idx += 1
                # next path is already prepared, so detect right away
                time_start = time.time()
                faces_once = predict_faces(next_face_img_path)

                # Any faces left?
                if not faces_once:  # no more faces detected, so stop
                    break

                # Accumulate the faces found so far
                faces.extend(faces_once)

                # Mask: black out the found faces so the next pass finds new ones
                t_face_img = cv2.imread(next_face_img_path)
                for face in faces_once:
                    location = face['location']
                    left = location['left']
                    top = location['top']
                    width = location['width']
                    height = location['height']
                    cv2.rectangle(t_face_img, (int(left - 5), int(top - 5)),
                                  (int(left + width + 5), int(top + height + 5)),
                                  color=(0, 0, 0),
                                  thickness=-1)  # filled mask

                # Save the masked image; it becomes the next search input
                new_face_img_name = 'face_search_temp' + str(times_idx) + '.jpg'
                next_face_img_path = os.path.join(temp_root, new_face_img_name)
                cv2.imwrite(next_face_img_path, t_face_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

                # Rate-limit two consecutive face-API calls.
                # NOTE(review): this sleeps the *elapsed* time; sleeping the
                # remainder (0.5 - elapsed) was probably intended — confirm.
                time_end = time.time()
                if time_end - time_start < 0.5:
                    time.sleep(time_end - time_start)

            # Face-search threshold: discard faces whose match score is too low
            face_search_th = 65
            # faces = filter(lambda x: x['score'] < face_search_th, faces)  # filter() version: drop faces below 65 points
            for idx in range(len(faces) - 1, -1, -1):  # delete while iterating backwards
                if faces[idx]['score'] < face_search_th:
                    faces.pop(idx)

            # Face gender detection
            time.sleep(0.5)  # flat 0.5 s pause to avoid face-API rate conflicts
            gender_faces = []
            times_idx = 0
            next_face_img_path = img_path
            while True:
                # print(next_face_img_path)
                times_idx += 1
                # next path is already prepared, so detect right away
                time_start = time.time()
                faces_once = face_detection(next_face_img_path)

                # Any faces left?
                if not faces_once:  # no more faces detected, so stop
                    break

                # Accumulate the faces found so far
                gender_faces.extend(faces_once)

                # Mask the found faces for the next pass
                t_face_img = cv2.imread(next_face_img_path)
                for face in faces_once:
                    location = face['location']
                    left = location['left']
                    top = location['top']
                    width = location['width']
                    height = location['height']
                    cv2.rectangle(t_face_img, (int(left - 15), int(top - 15)),
                                  (int(left + width + 15), int(top + height + 15)),
                                  color=(0, 0, 0),
                                  thickness=-1)  # filled mask

                # Save the masked image; it becomes the next detection input
                new_face_img_name = 'face_detection_temp' + str(times_idx) + '.jpg'
                next_face_img_path = os.path.join(temp_root, new_face_img_name)
                cv2.imwrite(next_face_img_path, t_face_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

                # Rate-limit two consecutive face-API calls.
                # NOTE(review): same elapsed-vs-remainder sleep issue as above.
                time_end = time.time()
                if time_end - time_start < 0.5:
                    time.sleep(time_end - time_start)

            # Compare the gender-detected faces against the searched ones and
            # drop the duplicates
            face_iou_th = 0.9
            for idx in range(len(gender_faces)-1, -1, -1):  # iterate backwards so pops are safe
                gender_face = gender_faces[idx]
                # NOTE(review): after popping idx the inner loop keeps running
                # and can pop the same index again; a break after the pop was
                # probably intended — confirm.
                for face in faces:
                    pass
                    iou = face_iou_cal(face, gender_face)
                    # print(iou)
                    if face_iou_cal(face, gender_face) > face_iou_th:  # IOU above threshold: discard this gender_face
                        gender_faces.pop(idx)
            # print(faces)

            # Draw the gender-detection boxes (unrecognized people)
            for gender_face in gender_faces:
                # Unpack the info
                location = gender_face['location']
                score = gender_face['score']
                gender = gender_face['gender']
                if gender == 'male':
                    gender = 'Male'
                if gender == 'female':
                    gender = 'Female'

                left = location['left']
                top = location['top']
                width = location['width']
                height = location['height']

                color = random_color()
                cv2.rectangle(src_img, (int(left), int(top)), (int(left + width), int(top + height)), color=color,
                              thickness=2)  # boxes
                cv2.rectangle(src_img, (int(left), int(top + height)), (int(left + width), int(top + height + 40)),
                              color=color,
                              thickness=-1)  # label background
                cv2.putText(src_img, 'Stranger', (int(left), int(top + height + 16)), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                            fontScale=0.5, color=(0, 0, 0), thickness=1)  # user_id
                cv2.putText(src_img, gender, (int(left), int(top + height + 32)), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                            fontScale=0.5, color=(0, 0, 0), thickness=1)  # gender

                # Append to the txt log
                with open(prediction_txt_path, 'a', encoding='utf-8') as f:
                    f.write('boxes:\t')
                    f.write('x1: ' + str(int(left)) + '\t')
                    f.write('y1: ' + str(int(top)) + '\t')
                    f.write('x2: ' + str(int(left + width)) + '\t')
                    f.write('y2: ' + str(int(top + height)) + '\t')
                    f.write(' name: ' + 'Stranger' + '\t')
                    f.write(' gender: ' + gender)
                    f.write('\n')

            # Draw the face-search (M:N) boxes (recognized people)
            for face in faces:
                # print('face')
                # Unpack the info
                location = face['location']
                user_id = face['user_id']
                gender = user_id.split(' ')[-1]
                user_id = user_id.split(' ')[0]

                score = face['score']

                left = location['left']
                top = location['top']
                width = location['width']
                height = location['height']

                color = random_color()
                cv2.rectangle(src_img, (int(left), int(top)), (int(left + width), int(top + height)), color=color,
                              thickness=2)  # boxes
                cv2.rectangle(src_img, (int(left), int(top + height)), (int(left + width), int(top + height + 40)),
                              color=color,
                              thickness=-1)  # label background
                cv2.putText(src_img, user_id, (int(left), int(top + height + 16)), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                            fontScale=0.5, color=(0, 0, 0), thickness=1)  # user_id
                cv2.putText(src_img, gender, (int(left), int(top + height + 32)), fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                            fontScale=0.5, color=(0, 0, 0), thickness=1)  # user_id
                # cv2.putText(src_img, str(round(score, 2)), (int(left), int(top+height)),
                #             fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1,
                #             color=color, thickness=2)  # score

                # Append to the txt log
                with open(prediction_txt_path, 'a', encoding='utf-8') as f:
                    f.write('boxes:\t')
                    f.write('x1: ' + str(int(left)) + '\t')
                    f.write('y1: ' + str(int(top)) + '\t')
                    f.write('x2: ' + str(int(left + width)) + '\t')
                    f.write('y2: ' + str(int(top + height)) + '\t')
                    f.write('name: ' + user_id + '\t')
                    f.write('gender: ' + gender + '\t')
                    f.write('\n')

            # Overlay the timestamp and the team name
            cv2.putText(src_img, time.strftime('%F--%Hh%Mm%Ss'), (0, 25), fontFace=cv2.FONT_HERSHEY_COMPLEX,
                        fontScale=1, color=(255, 0, 0), thickness=1)
            # cv2.putText(src_img, 'NNU--识别1队', (0, 50), fontFace=cv2.FONT_HERSHEY_COMPLEX,
            #             fontScale=1, color=(255, 0, 0), thickness=1)
            # PIL is used here because cv2.putText cannot render CJK glyphs.
            font_path = 'simsun.ttc'
            chinese_font = ImageFont.truetype(font_path, 40)
            team_name = 'NNU识别1队'
            img_pil = Image.fromarray(src_img)
            draw = ImageDraw.Draw(img_pil)
            draw.text((0, 50), team_name, font=chinese_font, fill=(255, 0, 0, 0))
            src_img = np.array(img_pil)

            # Could draw the raw mouse boxes here for debugging
            # cv2.rectangle(src_img, (int(mouse_param['x1']), int(mouse_param['y1'])),
            #               (int(mouse_param['x2']), int(mouse_param['y2'])), color=color, thickness=1)

        # Actually display the annotated image
        # src_img = cv2.resize(src_img, (1280, 720))
        print('Prediction Done!\n')
        cv2.imshow('img', src_img)
        # Save the image

        img_save_path = os.path.join(prediction_root, str(img_name_count) + 'PREDICTION_img.jpg')
        print('Save to' + img_save_path + '\n')
        print('Save to' + prediction_txt_path + '\n')
        cv2.imwrite(img_save_path, src_img, [int(cv2.IMWRITE_PNG_COMPRESSION), 0])

        # Press any key to move on to the next image
        print('Press any key to continue\n')
        print('Press ESC to exit\n')
        key = cv2.waitKey()
        if key == 27:
            break

    # Destroy all windows when the main loop exits
    cv2.destroyAllWindows()


if __name__ == '__main__':
    import argparse

    # Command line: -n/--num takes one or more integers (item limits);
    # -m/--model names the pickled model file (without extension) to load.
    cli = argparse.ArgumentParser(description=__doc__)
    cli.add_argument('-n', '--num', default=[0, 0, 0, 0, 0, 0, 0, 0], type=int, metavar='N1, N2, N3...',
                     help='num of epochs', required=True, nargs='+')
    cli.add_argument('-m', '--model', help='model_name', required=True)  # model name

    # Hand the parsed command-line arguments to the main routine.
    main(cli.parse_args())
